from Utils import *

# Read the stereotype table and echo it, followed by the number of rows
# found for each value of the `category` column.
df = pd.read_csv('../Social Bias Probing/SBIC-Pro-Stereotypes.csv')
print(df)

category_counts = df['category'].value_counts()
print(category_counts)

# Lower-case every stereotype in place, then keep the same texts as a plain
# Python list; its length (one entry per row) is printed as a sanity check.
lowercased = df['stereotype'].map(str.lower)
df['stereotype'] = lowercased
input_texts = lowercased.tolist()
print(len(input_texts))

# Metric object used for scoring. `load` comes from the Utils star-import
# (presumably Hugging Face `evaluate.load` -- TODO confirm).
perplexity = load("perplexity", module_type="metric")

# Texts are handed to the metric in slices of at most `batch_size` items.
batch_size = 5000
batch_starts = range(0, len(input_texts), batch_size)

# batch_perplexities_dict accumulates the raw per-text scores of each model;
# PPL holds the same scores rounded to three decimals.
batch_perplexities_dict = {LM: [] for LM in LMs}
PPL = {}

for model_name in LMs:
    raw_scores = batch_perplexities_dict[model_name]
    for start in batch_starts:
        chunk = input_texts[start:start + batch_size]
        outcome = perplexity.compute(model_id=model_name, predictions=chunk)
        raw_scores.extend(outcome['perplexities'])
        # Progress marker: offset of the slice that was just scored.
        print(f'Saved {start}')
    PPL[model_name] = [round(score, 3) for score in raw_scores]
    print('\n <----------------------> END of ' + model_name + '\n')
    
# Persist the rounded perplexities (one list per PPL key) as JSON.
with open('../Social Bias Probing/PPLs.json', 'w') as json_out:
    json.dump(PPL, json_out)

# Assemble one row per stereotype: the four descriptive columns taken from
# `df`, followed by one perplexity column per key of PPL. `zip` stops at the
# shortest input, so the table never has more rows than the shortest column.
base_columns = ['id', 'category', 'target', 'stereotype']
rows = list(zip(*(df[name] for name in base_columns), *PPL.values()))
stereotypes_w_PPL = pd.DataFrame(rows, columns=base_columns + list(PPL.keys()))

# Relabel columns according to the LMs_columns_names mapping, then export
# the table as CSV without the index column.
stereotypes_w_PPL = stereotypes_w_PPL.rename(columns=LMs_columns_names)
stereotypes_w_PPL.to_csv('../Social Bias Probing/Stereotypes-w-PPLs.csv', index=False)
print('\n\n <----------------------> END\n\n')